//
// Created by Xing on 2022/3/26.
//

#include <cmath>
#include <vector>

#include "caffe/filler.hpp"
#include "caffe/layers/reduce_l2_layer.hpp"
#include "caffe/util/math_functions.hpp"

namespace caffe {
    template<typename Dtype>
    void ReduceL2Layer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*> &bottom, const vector<Blob<Dtype>*> &top) {
        // Read the reduction settings from the layer prototxt. Bind the
        // parameter by const reference instead of copying the protobuf
        // message by value.
        const ReduceL2Parameter& param = this->layer_param_.reduce_l2_param();
        axis_ = param.axis();           // axis to reduce over
        keep_dim_ = param.keep_dim();   // keep a size-1 dim at axis_ in the output
        // Forward_cpu hard-codes an N x C x H x W walk that reduces over the
        // channel axis, so only axis 1 is supported.
        CHECK_EQ(axis_, 1) << " reduce axis only support 1.";
    }

    template<typename Dtype>
    void ReduceL2Layer<Dtype>::Reshape(const vector<Blob<Dtype>*> &bottom, const vector<Blob<Dtype>*> &top) {
        // Exactly one input blob and one output blob.
        CHECK_EQ(bottom.size(), top.size());
        CHECK_EQ(bottom.size(), 1);
        // Hoist the shape out of the loop and compare int against int —
        // the original loop compared a signed index with size_t.
        const vector<int>& bottom_shape = bottom[0]->shape();
        const int num_axes = static_cast<int>(bottom_shape.size());
        CHECK_LT(axis_, num_axes);
        num_ = bottom[0]->shape(0);  // batch size, cached for Forward_cpu
        // Output shape: same as input, with the reduced axis either dropped
        // (keep_dim_ == false) or collapsed to size 1 (keep_dim_ == true).
        vector<int> output_shape;
        output_shape.reserve(num_axes);
        for (int i = 0; i < num_axes; ++i) {
            if (i == axis_) {
                if (keep_dim_) {
                    output_shape.push_back(1);
                }
                continue;
            }
            output_shape.push_back(bottom_shape[i]);
        }
        top[0]->Reshape(output_shape);
    }

    template<typename Dtype>
    void ReduceL2Layer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*> &bottom, const vector<Blob<Dtype>*> &top) {
        // For every spatial position of every image, compute the L2 norm
        // across channels: top[n, h, w] = sqrt(sum_c bottom[n, c, h, w]^2).
        const Dtype* bottom_data = bottom[0]->cpu_data();
        Dtype* top_data = top[0]->mutable_cpu_data();
        const int channels = bottom[0]->shape(1);  // hoisted out of the hot loop
        const int CHW = bottom[0]->count(1);       // elements per image
        const int HW = bottom[0]->count(2);        // elements per channel plane
        for (int n = 0; n < num_; ++n) {
            for (int i = 0; i < HW; ++i) {
                // Accumulate in Dtype, not float: with Dtype == double the
                // original float accumulator silently lost precision.
                Dtype sum = 0;
                // Base pointer for this (n, i) column; channels are HW apart.
                const Dtype* src = bottom_data + n * CHW + i;
                for (int c = 0; c < channels; ++c) {
                    const Dtype val = src[c * HW];
                    sum += val * val;
                }
                top_data[n * HW + i] = std::sqrt(sum);
            }
        }
    }
    // Instantiate the float and double versions of the layer and register it
    // with the Caffe layer factory under the type name "ReduceL2".
    INSTANTIATE_CLASS(ReduceL2Layer);
    REGISTER_LAYER_CLASS(ReduceL2);
}  // namespace caffe